{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "matchzoo version 1.0\n",
      "`ranking_task` initialized with metrics [normalized_discounted_cumulative_gain@3(0.0), normalized_discounted_cumulative_gain@5(0.0), mean_average_precision(0.0)]\n",
      "data loading ...\n",
      "data loaded as `train_pack_raw` `dev_pack_raw` `test_pack_raw`\n"
     ]
    }
   ],
   "source": [
    "%run init.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "ranking_task = mz.tasks.Ranking(losses=mz.losses.RankCrossEntropyLoss(num_neg=4))\n",
    "ranking_task.metrics = [\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=3),\n",
    "    mz.metrics.NormalizedDiscountedCumulativeGain(k=5),\n",
    "    mz.metrics.MeanAveragePrecision()\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 10688.44it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6461.84it/s]\n",
      "Processing text_right with append: 100%|██████████| 18841/18841 [00:00<00:00, 1024195.57it/s]\n",
      "Building FrequencyFilter from a datapack.: 100%|██████████| 18841/18841 [00:00<00:00, 156514.98it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 150658.17it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 660633.29it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 716006.14it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 418401/418401 [00:00<00:00, 3113105.21it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 57316.46it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 20611.73it/s]\n",
      "Processing text_left with extend: 100%|██████████| 2118/2118 [00:00<00:00, 512699.02it/s]\n",
      "Processing text_right with extend: 100%|██████████| 18841/18841 [00:00<00:00, 384714.05it/s]\n",
      "Building Vocabulary from a datapack.: 100%|██████████| 2082963/2082963 [00:00<00:00, 4107566.35it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2118/2118 [00:00<00:00, 11845.98it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 18841/18841 [00:02<00:00, 6507.02it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 146201.59it/s]\n",
      "Processing text_left with transform: 100%|██████████| 2118/2118 [00:00<00:00, 220123.79it/s]\n",
      "Processing text_right with transform: 100%|██████████| 18841/18841 [00:00<00:00, 132351.14it/s]\n",
      "Processing length_left with len: 100%|██████████| 2118/2118 [00:00<00:00, 648054.85it/s]\n",
      "Processing length_right with len: 100%|██████████| 18841/18841 [00:00<00:00, 813967.84it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 122/122 [00:00<00:00, 9048.56it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 1115/1115 [00:00<00:00, 6184.72it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 160428.42it/s]\n",
      "Processing text_left with transform: 100%|██████████| 122/122 [00:00<00:00, 132806.93it/s]\n",
      "Processing text_right with transform: 100%|██████████| 1115/1115 [00:00<00:00, 127165.79it/s]\n",
      "Processing length_left with len: 100%|██████████| 122/122 [00:00<00:00, 182426.06it/s]\n",
      "Processing length_right with len: 100%|██████████| 1115/1115 [00:00<00:00, 594388.53it/s]\n",
      "Processing text_left with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 237/237 [00:00<00:00, 10068.78it/s]\n",
      "Processing text_right with chain_transform of Tokenize => Lowercase => PuncRemoval: 100%|██████████| 2300/2300 [00:00<00:00, 6338.85it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 160199.59it/s]\n",
      "Processing text_left with transform: 100%|██████████| 237/237 [00:00<00:00, 161005.84it/s]\n",
      "Processing text_right with transform: 100%|██████████| 2300/2300 [00:00<00:00, 121958.27it/s]\n",
      "Processing length_left with len: 100%|██████████| 237/237 [00:00<00:00, 250327.39it/s]\n",
      "Processing length_right with len: 100%|██████████| 2300/2300 [00:00<00:00, 652832.05it/s]\n"
     ]
    }
   ],
   "source": [
    "preprocessor = mz.models.DSSM.get_default_preprocessor(ngram_size=3)\n",
    "train_pack_processed = preprocessor.fit_transform(train_pack_raw)\n",
    "valid_pack_processed = preprocessor.transform(dev_pack_raw)\n",
    "test_pack_processed = preprocessor.transform(test_pack_raw)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'ngram_process_unit': <matchzoo.preprocessors.units.ngram_letter.NgramLetter at 0x13b274e10>,\n",
       " 'filter_unit': <matchzoo.preprocessors.units.frequency_filter.FrequencyFilter at 0x10b1f3550>,\n",
       " 'vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x10b71ea20>,\n",
       " 'vocab_size': 30058,\n",
       " 'embedding_input_dim': 30058,\n",
       " 'ngram_vocab_unit': <matchzoo.preprocessors.units.vocabulary.Vocabulary at 0x13e2a7f98>,\n",
       " 'ngram_vocab_size': 9654}"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "preprocessor.context"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "triletter_callback = mz.dataloader.callbacks.Ngram(\n",
    "    preprocessor, mode='aggregate')\n",
    "\n",
    "trainset = mz.dataloader.Dataset(\n",
    "    data_pack=train_pack_processed,\n",
    "    mode='pair',\n",
    "    num_dup=1,\n",
    "    num_neg=4,\n",
    "    callbacks=[triletter_callback]\n",
    ")\n",
    "testset = mz.dataloader.Dataset(\n",
    "    data_pack=test_pack_processed,\n",
    "    callbacks=[triletter_callback]\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "padding_callback = mz.models.DSSM.get_default_padding_callback()\n",
    "\n",
    "trainloader = mz.dataloader.DataLoader(\n",
    "    dataset=trainset,\n",
    "    batch_size=32,\n",
    "    stage='train',\n",
    "    resample=True,\n",
    "    callback=padding_callback\n",
    ")\n",
    "testloader = mz.dataloader.DataLoader(\n",
    "    dataset=testset,\n",
    "    batch_size=32,\n",
    "    stage='dev',\n",
    "    callback=padding_callback\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "DSSM(\n",
      "  (mlp_left): Sequential(\n",
      "    (0): Sequential(\n",
      "      (0): Linear(in_features=9654, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (1): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (2): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (3): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=128, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (mlp_right): Sequential(\n",
      "    (0): Sequential(\n",
      "      (0): Linear(in_features=9654, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (1): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (2): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=300, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "    (3): Sequential(\n",
      "      (0): Linear(in_features=300, out_features=128, bias=True)\n",
      "      (1): ReLU()\n",
      "    )\n",
      "  )\n",
      "  (out): Linear(in_features=1, out_features=1, bias=True)\n",
      ") 6231258\n"
     ]
    }
   ],
   "source": [
    "model = mz.models.DSSM()\n",
    "\n",
    "model.params['task'] = ranking_task\n",
    "model.params['vocab_size'] = preprocessor.context['ngram_vocab_size']\n",
    "model.params['mlp_num_layers'] = 3\n",
    "model.params['mlp_num_units'] = 300\n",
    "model.params['mlp_num_fan_out'] = 128\n",
    "model.params['mlp_activation_func'] = 'relu'\n",
    "\n",
    "model.build()\n",
    "\n",
    "print(model, sum(p.numel() for p in model.parameters() if p.requires_grad))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "optimizer = torch.optim.Adam(model.parameters())\n",
    "\n",
    "trainer = mz.trainers.Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    trainloader=trainloader,\n",
    "    validloader=testloader,\n",
    "    validate_interval=None,\n",
    "    epochs=10\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3022b6510f204f7fba9f8fdf0abdf037",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-32 Loss-1.521]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.5219 - normalized_discounted_cumulative_gain@5(0.0): 0.5864 - mean_average_precision(0.0): 0.5432\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8df2a3c656d4164b7a913397ad0ef85",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-64 Loss-1.343]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4888 - normalized_discounted_cumulative_gain@5(0.0): 0.5738 - mean_average_precision(0.0): 0.5166\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5e4dfcb292cc404d89fa26ea1dbe5122",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-96 Loss-1.320]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4447 - normalized_discounted_cumulative_gain@5(0.0): 0.5118 - mean_average_precision(0.0): 0.4652\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "29296f0d9ae740f1b9dc2b74a5b72849",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-128 Loss-1.262]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4977 - normalized_discounted_cumulative_gain@5(0.0): 0.5674 - mean_average_precision(0.0): 0.5216\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "af7fa97a5fbb4171a07149d57501461a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-160 Loss-1.213]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4819 - normalized_discounted_cumulative_gain@5(0.0): 0.5456 - mean_average_precision(0.0): 0.4997\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a03bc789d05a4ec993dbcfe8bf6b55c8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-192 Loss-1.186]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4921 - normalized_discounted_cumulative_gain@5(0.0): 0.5655 - mean_average_precision(0.0): 0.5069\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ce02df612a72431f808af6a441472e68",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-224 Loss-1.177]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4859 - normalized_discounted_cumulative_gain@5(0.0): 0.5583 - mean_average_precision(0.0): 0.5056\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71fa98f2fb854d008c5cd8fcef3b05d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-256 Loss-1.137]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4789 - normalized_discounted_cumulative_gain@5(0.0): 0.5499 - mean_average_precision(0.0): 0.4926\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3c0cbb5b90b1495eb0937320c1bef6f6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-288 Loss-1.126]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4967 - normalized_discounted_cumulative_gain@5(0.0): 0.5725 - mean_average_precision(0.0): 0.5274\n",
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "66f27dd85e044371b831ed7e511a3880",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=32), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[Iter-320 Loss-1.092]:\n",
      "  Validation: normalized_discounted_cumulative_gain@3(0.0): 0.4868 - normalized_discounted_cumulative_gain@5(0.0): 0.55 - mean_average_precision(0.0): 0.506\n",
      "\n",
      "Cost time: 1625.7705359458923s\n"
     ]
    }
   ],
   "source": [
    "trainer.run()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
